In [ ]:
# Imports
import os
import torch
from tqdm.notebook import tqdm
import fancy_einsum as fancy
from mech_interp.fixTL import make_official
from mech_interp.visualizations import (
    plot_board_log_probs,
    map_token_to_move_index,
    preprocess_offset_mapping)
from mech_interp.utils import pretty_moves
from transformer_lens import HookedTransformer
import pandas as pd
from austin_plotly import imshow
import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
import chess
import plotly 

# Inference-only notebook: disable autograd globally to save memory/compute.
torch.set_grad_enabled(False)
# Required so plotly figures render offline inside the notebook.
plotly.offline.init_notebook_mode()
# plotly.io.renderers.default = 'notebook' 
In [ ]:
# Default figure dimensions (pixels) read by plot_combined_probe_output.
# NOTE(review): these globals are reassigned later in the notebook (HEIGHT=450,
# WIDTH=1200 in the figure-building cell), so the effective values depend on
# execution order — a hidden-state hazard under Restart & Run All.
HEIGHT=800
WIDTH=1200

def plot_combined_probe_output(output: torch.Tensor, clip_size, titles, df_data, return_fig = False):
    """Render an animated row of heatmaps, one per probe class, with a token slider.

    Parameters
    ----------
    output : torch.Tensor
        Probe output for one game; the last dimension indexes classes and is
        assumed to align with `titles` (earlier cells pass shape
        (pos, rows, cols, classes) — TODO confirm for other callers).
    clip_size : float
        Values are clipped to [-clip_size, clip_size] before plotting.
    titles : list[str]
        One title per class channel; also determines the number of subplots.
    df_data : pandas.Series or None
        A single game's row (needs 'offsets', 'fen_stack', 'transcript').
        When given, the board's piece letters are overlaid on each heatmap.
    return_fig : bool
        If True return the figure instead of calling fig.show().

    Uses module-level HEIGHT/WIDTH for the final figure size.
    """
    # Compare device *type*, not the torch.device object against a string.
    if output.device.type != 'cpu':
        output = output.cpu()

    # One (num_frames, 8, 8) array per class, clipped for a symmetric colorscale.
    # flip(-1) orients the board the way the chess text overlay expects.
    tensor_nps = [
        np.clip(output[..., i].squeeze().flip(-1).numpy(), -clip_size, clip_size)
        for i in range(len(titles))
    ]

    # Number of animation frames (token positions); identical across classes.
    num_frames = tensor_nps[0].shape[0]

    fig = make_subplots(rows=1, cols=len(titles))

    # Symmetric diverging colorscale centered at 0 (white).
    diverging = [
        [0, 'blue'],   # low values
        [0.5, 'white'],  # midpoint at 0.0
        [1, 'red'],    # high values
    ]

    # Build one frame per token position. The colorscale is set on the frame
    # traces too, otherwise the animation would fall back to plotly's default
    # colors and disagree with the initial traces.
    frames = []
    for frame_idx in range(num_frames):
        frame_data = [
            go.Heatmap(z=tensor_np[frame_idx, :, :], xaxis=f'x{i+1}', yaxis=f'y{i+1}',
                       colorscale=diverging)
            for i, tensor_np in enumerate(tensor_nps)
        ]
        frames.append(go.Frame(data=frame_data, name=str(frame_idx)))

    # Initial (frame 0) trace for each subplot, plus its axis title.
    for i, (title, tensor_np) in enumerate(zip(titles, tensor_nps)):
        fig.add_trace(
            go.Heatmap(z=tensor_np[0, :, :], xaxis=f'x{i+1}', yaxis=f'y{i+1}',
                       colorscale=diverging),
            row=1, col=i+1,
        )
        # fig.update_yaxes(scaleanchor=f"x{i+1}", scaleratio=1, row=1, col=i+1)
        fig.update_layout(**{f'xaxis{i+1}': dict(anchor=f'y{i+1}', title=title)})

    # Slider: one step per frame, jumping immediately (no transition animation).
    fig.update_layout(
        sliders=[dict(
            steps=[dict(method='animate',
                        args=[[str(frame_idx)], 
                              {"frame": {"duration": 0, "redraw": True},
                               "mode": "immediate",
                               "transition": {"duration": 0}}],
                        label=str(frame_idx))
                    for frame_idx in range(num_frames)],
            active=0,
            transition={"duration": 0},
            x=0, y=0,
            currentvalue={"prefix": "Token: ", "visible": True},
            len=1
        )],
    )

    fig.frames = frames

    if df_data is not None:
        # Overlay the board position (piece letters) for the move each token
        # belongs to. map_token_to_move_index translates token index -> move
        # index via the tokenizer offsets.
        preprocessed_offsets = preprocess_offset_mapping(df_data['offsets'])
        fen_stack = df_data['fen_stack']
        move_indices = [map_token_to_move_index(df_data['transcript'], i + 1, preprocessed_offsets, zero_index=True)
                        for i in range(len(fig.frames) - 1)]
        heatmaps_texts = []
        for move_idx in move_indices:
            fen = fen_stack[move_idx]
            board = chess.Board(fen)
            strip_all = lambda s: s.replace(' ', '').replace('\n', '')
            text = [f"<span style='font-size: 12em; align: center;'>{c}</span>" 
                    for c in strip_all(str(board))] #.unicode(empty_square='·')
            heatmaps_texts.append(np.array(text).reshape(8, 8))

        # Add the text to the initial heatmap traces.
        for i in range(len(titles)):
            fig.data[i]['text'] = heatmaps_texts[0]
            fig.data[i]['texttemplate'] = "%{text}"

        # Add the text to every trace of every frame. (Previously hard-coded
        # to frame.data[0..2], which broke for len(titles) != 3.)
        # There is one fewer text board than frames, hence the bounds guard.
        for idx, frame in enumerate(fig.frames):
            if idx < len(heatmaps_texts):
                for trace in frame.data:
                    trace["text"] = heatmaps_texts[idx]
                    trace["texttemplate"] = "%{text}"

    fig.update_layout(
        height=HEIGHT,  # module-level globals; see config cell
        width=WIDTH,
        margin=dict(l=20, r=20, t=50, b=20),
        grid=dict(rows=1, columns=len(titles), pattern='independent'),  # independent axes per subplot
    )

    if return_fig:
        return fig
    else:
        fig.show()
In [ ]:
# Load the chess-GPT2 model as a HookedTransformer so we can read activations.
model = HookedTransformer.from_pretrained(make_official())

# Test split of lichess games; each row is one game (see printed columns below).
DATA_DIR = '../chess_data/lichess_test.pkl'
df = pd.read_pickle(DATA_DIR)
print(df.keys())
# First 15 games' token ids, batched into one tensor.
# NOTE(review): assumes all 15 games share the same tokenized length — confirm.
input_ids = torch.tensor(df['input_ids'][0:15].tolist())
# Cache every activation; later cells index cache['resid_post', layer].
logits, cache = model.run_with_cache(input_ids)
Using eos_token, but it is not set yet.
Loaded pretrained model AustinD/gpt2-chess-uci-hooked into HookedTransformer
Index(['WhiteElo', 'BlackElo', 'Result', 'complete_transcript', 'input_ids',
       'offsets', 'transcript', 'WhiteEloBinned', 'WhiteEloBinIndex',
       'BlackEloBinned', 'BlackEloBinIndex', 'fen_stack'],
      dtype='object')

Next, I'll load the probes and create a figure for each one, showing the output for the 'white', 'black', and 'empty' predictions.

In [ ]:
# Layer whose residual stream the probe was trained on.
PROBE_NUMBER = 6

PROBE_NAME = f'../linear_probes/saved_probes/color_probe_gpt2-chess-uci-hooked_layer_{PROBE_NUMBER}_indexing_df_to_color_state.pth'

# Checkpoint dict; 'linear_probe' holds the probe weights (see printed keys).
# NOTE(review): torch.load unpickles arbitrary code — only load trusted files.
probe = torch.load(PROBE_NAME)
print(probe.keys())
linear_probe = probe['linear_probe']

# Residual stream after PROBE_NUMBER's block, from the cache built above.
resids = cache['resid_post', PROBE_NUMBER]
# Project each position's residual onto per-square, per-class probe directions:
# result is (batch, pos, rows, cols, classes) — see printed shape below.
probe_output = fancy.einsum(
    "batch pos d_model, d_model rows cols classes -> batch pos rows cols classes",
    resids,
    linear_probe,
)
print(probe_output.shape, ' ', probe_output[2,...,1].squeeze().shape, sep = '\n')
dict_keys(['acc', 'loss', 'lr', 'epoch', 'batch', 'linear_probe', 'linear_probe_name', 'model_name', 'layer', 'indexing_function_name', 'batch_size', 'wd', 'split', 'num_epochs', 'num_classes', 'wandb_project', 'wandb_run_name', 'dataset_prefix'])
torch.Size([15, 126, 8, 8, 3])
 
torch.Size([126, 8, 8])

Note: the text in the plotly cells gets clipped at small figure sizes. This can be solved by using a larger figure. HTML viewers should zoom out to view all 3 panels.

In [ ]:
# Build one animated figure per saved probe checkpoint.
clip = 0.4      # symmetric clip applied to probe outputs before plotting
game_id = 13    # which of the 15 cached games to visualize
# Overrides the module-level defaults read by plot_combined_probe_output.
HEIGHT = 450
WIDTH = 1200
SAVED_PROBE_DIR = '../linear_probes/saved_probes/'

figures = []

for file in tqdm(sorted(os.listdir(SAVED_PROBE_DIR)), 'Building probe figures'):
    # BUG FIX: load the checkpoint BEFORE building the title. Previously the
    # title was computed from the stale `probe` of the preceding iteration
    # (and, on the first iteration, from the earlier single-probe cell), so
    # every figure was titled with the wrong probe's metadata.
    probe = torch.load(SAVED_PROBE_DIR+file)
    title = ' | '.join([probe['indexing_function_name']]+[f'{k}:{probe.get(k):.5f}' for k in ['acc', 'loss', 'lr', 'epoch', 'layer']])

    linear_probe = probe['linear_probe']
    target_layer = int(probe['layer'])

    # Project this probe's layer's residual stream onto its square/class directions.
    resids = cache['resid_post', target_layer]
    probe_output = fancy.einsum(
        "batch pos d_model, d_model rows cols classes -> batch pos rows cols classes",
        resids,
        linear_probe,
    )

    fig = plot_combined_probe_output(probe_output[game_id].squeeze(), clip, ['WHITE','BLACK','EMPTY'], df.iloc[game_id], return_fig=True)
    fig.update_layout(title=title)
    figures.append(fig)
Building probe figures:   0%|          | 0/13 [00:00<?, ?it/s]
In [ ]:
# Print the game's moves so the figure's token slider can be read against them.
print(f"Game {game_id} transcript:\n{pretty_moves(df.iloc[game_id]['transcript'])}")

# Show the first probe's figure (sorted filename order from the loop above).
display(figures[0])
Game 13 transcript:
Moves for  White  vs  Black 
 d2d4 g8f6 e2e3 g7g6 f1d3 f8g7 g1f3 e8g8 e1g1 d7d6
 c2c4 c7c5 b1c3 b8d7 h2h3 a7a6 c1d2 a8b8 d1e2 b7b6
 e3e4 c8b7 d4d5 e7e5 d5e6 f7e6 a1d1 d8c7 d2f4 f6e8
 f4g3 g7d4 f3d4 c5d4 c3b1 e6e5 b1d2 d7c5 d3b1 a6a5
 f2f4 e8g7 d1e1 b8e8 e2g4 c7c8 f4f5 g8h8 d2b3 g6f5
 e4f5 e5e4 b3d4 c5d3 b1d3 e4d3 d4e6 g7e6 f5e6 c8c6
 g3h4 c6c5 g1
In [ ]: